import sys
from BasicLibraries import *
import pickle
import Functions.Miscellaneous.utils as ut
import allocation_ga as af
from sklearn.inspection import permutation_importance
import os

# Shared allocation-model helper; the functions below call its `fit_model`
# and `make_importance_matrix` methods.
AM = af.AllocationGAD()



# NOTE(review): not referenced in this section; presumably a base-path prefix — confirm.
ONEDRIVE = ''
# Project-wide lookups from the utils helper: the initiative list and the
# colour palette (LCGreen is used by plot_feature_importance).
initiatives = ut.utils().return_Initiatives()
LCGreen, LCYellow, LCBlue = ut.utils().get_LCColors()

#================
# Seed NumPy's global (legacy) RNG so runs are reproducible.
np.random.seed(5)
#================
#=== Get year range
# Default analysis window [max_year - 4, max_year], only set when the landscape
# population file exists in the working directory.
# NOTE(review): when the file is missing only the placeholder `t` is set and
# min_year/max_year stay undefined; `Z` is loaded but not used in this section.
# Confirm whether this module-level window is still needed.
if os.path.exists('0_5_LandscapePopulation_0724_full.txt'):
    Z = pd.read_csv('0_5_LandscapePopulation_0724_full.txt')
    max_year = 2021
    min_year =max_year-4
else:
    t=0
#===


#%%
def get_feature_importance_(dt, I, env_sdgs, normal_fit, population_type,
                            feature_type,
                            permutation_test_type,
                            final_year = 2021,
                            save_output=True,
                            energy_control=False):
    """Estimate rolling permutation feature importance of the RandomForest model.

    For every year Y in ``[final_year - 3, final_year]`` the model is refit via
    ``AM.fit_model`` on the rows with ``rfyear <= Y``. Permutation importance is
    then computed on the fitted data and normalised to sum to one.

    Parameters
    ----------
    dt : pandas.DataFrame
        Panel containing at least ``next_year_performance``, ``mrg``,
        ``gvkey`` and ``rfyear``.
    I, env_sdgs, normal_fit, energy_control
        Forwarded to ``AM.fit_model`` as ``dimension_``, ``sdgs_number``,
        ``normal_fit`` and ``energy_price_control`` respectively.
    population_type, permutation_test_type
        Unused; kept for call compatibility.
    feature_type : str
        Path of the pickle written when ``save_output`` is True.
    final_year : int
        Last year of the rolling window.
    save_output : bool
        If True, pickle ``[average_importance_factors, aM, aS, sM, sS,
        feature_importance]`` to ``feature_type`` and return None.

    Returns
    -------
    None when ``save_output`` is True. Otherwise the tuple
    ``(average_importance_factors, aM, aS, sM, sS, feature_importance,
    years_range, gvkey_list)``. The last element is the firm universe that
    was analysed.
    """
    dt = dt[dt.next_year_performance > -3]
    # Keep the last record per 'mrg' key.
    dt = dt.groupby('mrg').last()

    # Firm universe: gvkeys with at least 3 observations from 2015 onwards.
    recent = dt[dt.rfyear >= 2015]
    obs_per_firm = recent[['gvkey', 'rfyear']].groupby('gvkey').count()
    obs_per_firm = obs_per_firm[obs_per_firm >= 3].dropna()
    gvkey_list = list(np.unique(dt[dt.gvkey.isin(obs_per_firm.index)].gvkey))
    print('Number of firms to analyse:', len(gvkey_list))

    max_year = final_year + 1
    min_year = max_year - 4
    years_range = list(range(min_year, max_year))

    average_importance_factors = []
    feature_importance = dict()

    #== Only for impurity importance; left empty so the output layout is stable.
    aM, aS, sM, sS = pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    for Y in years_range:
        print('Year', Y)
        # No explicit resampling: expanding window up to year Y.
        dat = dt[dt.rfyear <= Y].copy()
        x, y, forst, sdgs, _, _, _, _, _ = AM.fit_model(dat,
                                                        sdgs_number=env_sdgs,
                                                        dimension_=I,
                                                        gvkey_list=gvkey_list,
                                                        model_type='RandomForest',
                                                        normal_fit=normal_fit,
                                                        energy_price_control=energy_control)
        # The first len(sdgs) features are behavioural; the next 6 are treated
        # as financial controls (as in the original layout).
        bdims = len(sdgs)
        tdims_ = bdims + 6

        ft_importance = permutation_importance(forst, np.array(x), np.array(y),
                                               n_repeats=10, random_state=5, n_jobs=4)
        forest_importances = pd.Series(ft_importance.importances_mean, index=x.columns)
        forest_importances = forest_importances / forest_importances.sum()

        # Year fixed effects change on a rolling basis, so importance shares are
        # rescaled over the features that are constant across years only.
        constant_scaler = forest_importances.iloc[:tdims_].sum()
        average_importance_factors.append([
            Y,
            forest_importances.iloc[:bdims].sum() / constant_scaler,
            forest_importances.iloc[bdims:tdims_].sum() / constant_scaler,
            np.nan,
        ])
        feature_importance[str(Y)] = AM.make_importance_matrix(
            pd.DataFrame(forest_importances.iloc[:bdims])) / constant_scaler

    average_importance_factors = pd.DataFrame(
        average_importance_factors,
        columns=['rfyear', 'behaviour', 'financial', 'region'])

    if save_output:
        print('Saving Feature Importance Output')
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(feature_type, 'wb') as fh:
            pickle.dump([average_importance_factors,
                         aM, aS, sM, sS,
                         feature_importance], fh)
        return None

    print('Return Feature Importance Output')
    # Previously returned the undefined name `subset_gvkey` (NameError).
    return (average_importance_factors, aM, aS, sM, sS, feature_importance,
            years_range, gvkey_list)

#===============
def get_importance_table(I, env_sdgs, normal_fit, population_type):
    """Build and print a LaTeX table of whole-sample permutation importances.

    Parameters
    ----------
    I, env_sdgs, normal_fit
        Forwarded to ``AM.fit_model`` as ``dimension_``, ``sdgs_number`` and
        ``normal_fit``.
    population_type : str
        Path of the CSV population file read with ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame
        Pivot table (rows: Action, columns: SDG) of ``mean$\\pm$std`` strings.
        A cell is bold when the rounded mean exceeds the rounded std, and gray
        otherwise. The LaTeX rendering is also printed.
    """
    dt = pd.read_csv(population_type, low_memory=False)
    dt = dt[dt.next_year_performance > -3]
    # Keep the last record per 'mrg' key.
    dt = dt.groupby('mrg').last()

    # Firm universe: gvkeys with at least 3 observations from 2015 onwards.
    recent = dt[dt.rfyear >= 2015]
    obs_per_firm = recent[['gvkey', 'rfyear']].groupby('gvkey').count()
    obs_per_firm = obs_per_firm[obs_per_firm >= 3].dropna()
    gvkey_list = list(np.unique(dt[dt.gvkey.isin(obs_per_firm.index)].gvkey))
    print('Number of firms to analyse:', len(gvkey_list))

    x, y, forst, _, _, _, _, _, _ = AM.fit_model(dt,
                                                 sdgs_number=env_sdgs,
                                                 dimension_=I,
                                                 gvkey_list=gvkey_list,
                                                 model_type='RandomForest',
                                                 normal_fit=normal_fit)

    print('\nEstimating permutation importance from the whole sample')
    ft_importance = permutation_importance(forst, np.array(x), np.array(y),
                                           n_repeats=20, random_state=42, n_jobs=2)

    rows = []
    for feature, mean, std in zip(x.columns,
                                  ft_importance['importances_mean'],
                                  ft_importance['importances_std']):
        mean_r, std_r = np.round(mean, 3), np.round(std, 3)
        body = str(mean_r) + r'$\pm$' + str(std_r) + '}'
        # Raw strings are required: a plain '\textcolor' would turn '\t' into a TAB.
        if mean_r - std_r > 0:
            cell = r'{\bf ' + body
        else:
            cell = r'\textcolor{gray}{' + body
        rows.append([feature, cell])

    table = pd.DataFrame(rows, columns=['Feature', 'Importance'])
    # NOTE(review): assumes the first 24 features are the 'Action - SDG'
    # behavioural features; confirm against the model's feature layout.
    table = table.iloc[:24].copy()
    table['Action'] = table['Feature'].apply(lambda name: name.split(' - ')[0])
    table['SDG'] = table['Feature'].apply(lambda name: name.split(' - ')[1])
    # pandas >= 2.0 makes pivot's arguments keyword-only.
    table = table[['Action', 'SDG', 'Importance']].pivot(
        index='Action', columns='SDG', values='Importance')
    print(table.to_latex(escape=False))
    return table

#===============
def plot_feature_importance(feature_type, min_year, max_year, save_plot=False,
                            baseline_year=2018):
    """Summarise (and optionally plot) feature importance stored by get_feature_importance_.

    Parameters
    ----------
    feature_type : str
        Path of the pickle holding ``[average_importance_factors, aM, aS, sM,
        sS, feature_importance]``. Only load trusted files: unpickling can
        execute arbitrary code.
    min_year, max_year : int
        Years ``min_year .. max_year - 1`` are averaged.
    save_plot : bool
        When True, draw the stacked bar chart and the heatmap and save them to
        ``Figures/Fig_FI.pdf`` and ``Figures/Fig_FI_Mat.pdf``. In that case the
        returned ``u`` gains a 'Total' row/column and is made absolute.
    baseline_year : int
        Year whose matrix labels seed the common alignment frame. Falls back
        to ``max_year - 1`` when that year is not in the pickle.

    Returns
    -------
    (M, u, s, average_importance_factors)
        ``M`` holds the per-year mean behavioural/financial/regional shares.
        ``u`` and ``s`` are the element-wise mean and std of the yearly
        importance matrices.
    """
    with open(feature_type, 'rb') as fh:
        average_importance_factors, aM, aS, sM, sS, feature_importance = pickle.load(fh)

    M = average_importance_factors.groupby('rfyear').mean()
    M.columns = ['Behavioural', 'Financial', 'Regional']
    if save_plot:
        M.plot.bar(stacked=True, color=[LCGreen, 'RoyalBlue', 'lightcoral'], rot=0)
        plt.xlabel('')
        plt.ylabel('Relative importance of features')
        plt.legend(loc='center', bbox_to_anchor=(0.5, 1.1), ncol=3)
        plt.savefig('Figures/Fig_FI.pdf', bbox_inches='tight', dpi=100)

    years = range(min_year, max_year)
    # Adding a zero frame aligns every year's matrix onto a common label set.
    # The baseline year used to be hard-coded to 2018, which raised KeyError
    # whenever that year was not stored.
    if str(baseline_year) not in feature_importance:
        baseline_year = max_year - 1
    reference = feature_importance[str(baseline_year)]
    feature_baseline = pd.DataFrame(0, columns=reference.columns, index=reference.index)
    for Y in years:
        feature_importance[str(Y)] = feature_baseline + feature_importance[str(Y)]

    latest = feature_importance[str(max_year - 1)]
    yearly = [feature_importance[str(Y)] for Y in years]
    u = pd.DataFrame(np.mean(yearly, axis=0), columns=latest.columns, index=latest.index)
    s = pd.DataFrame(np.std(yearly, axis=0), columns=latest.columns, index=latest.index)

    if save_plot:
        plt.figure(figsize=(16, 8))
        u.loc[:, 'Total'] = u.sum(axis=1)
        u.loc['Total', :] = u.sum(axis=0)
        u = u.abs()
        plt.title('Features importance, %', y=1.05)
        ax = sns.heatmap(np.round(u * 100, 1).transpose(), annot=True, cbar=False, fmt='.2g',
                         vmin=0, vmax=2,
                         center=1, linewidths=0.8,
                         annot_kws={"color": "k", 'size': '24'},
                         linecolor='black', alpha=0.6,
                         cmap='RdYlBu')
        plt.xlabel('')
        plt.ylabel('')
        ax.yaxis.tick_right()
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=24)
        plt.savefig('Figures/Fig_FI_Mat.pdf', bbox_inches='tight', dpi=100)
    return M, u, s, average_importance_factors
def barplot_feature_incidence(A, B, save_plot):
    """Draw two side-by-side bar charts of importance/incidence ratios.

    Parameters
    ----------
    A, B : pandas.Series
        Ratios for the left and right panel. Each must contain a 'Total'
        label, which is dropped before plotting (``drop`` raises KeyError
        otherwise).
    save_plot : bool
        When True, save the figure to 'Figures/Fig_importance_incidence.pdf'.
    """
    plt.figure(figsize=(18, 6))
    for panel, ratios in enumerate([A, B], start=1):
        ratios = ratios.drop(index=['Total'])
        # One row, two columns; `panel` selects the slot.
        ax = plt.subplot(1, 2, panel)
        ratios.plot.bar(color='orange', ax=ax)
        plt.xlabel('')
        # Only the left panel carries a y-axis label.
        plt.ylabel(r'Importance/incidence' if panel == 1 else '')
        # Reference line at a ratio of 1.
        plt.axhline(y=1, ls='-.', c='k', lw=1)
    if save_plot:
        plt.savefig('Figures/Fig_importance_incidence.pdf', bbox_inches='tight', dpi=100)

def plot_feature_importance_versus_incidence(empirical_, population_type, actions, sdgs, feature_type, feature_importance = 0, plot_figure=True, save_plot=False):
    """Compare average feature importance with empirical incidence.

    Parameters
    ----------
    empirical_ : pandas.DataFrame
        Incidence table with a 'Total' row and a 'Total' column.
    population_type, actions, sdgs
        Unused; kept for call compatibility.
    feature_type : str
        Pickle path holding ``[_, _, _, _, _, feature_importance]``. It is read
        only when ``feature_importance`` is 0 or None. Only load trusted
        files: unpickling can execute arbitrary code.
    feature_importance : dict, optional
        Mapping year-string -> importance matrix (DataFrame).
    plot_figure, save_plot : bool
        Passed through to ``barplot_feature_incidence`` when ``plot_figure``
        is True.

    Returns
    -------
    (A, B)
        ``A`` is the row-total importance share divided by
        ``empirical_['Total']``. ``B`` is the column-total share divided by
        ``empirical_.loc['Total']``.
    """
    # 0 is the historical "not supplied" sentinel; None is accepted as well.
    not_supplied = feature_importance is None or (
        isinstance(feature_importance, (int, float)) and feature_importance == 0)
    if not_supplied:
        with open(feature_type, 'rb') as fh:
            _, _, _, _, _, feature_importance = pickle.load(fh)

    # Element-wise mean over all stored years, labelled like the latest year.
    yearly = list(feature_importance.values())
    latest = yearly[-1]
    u = pd.DataFrame(np.mean(yearly, axis=0), columns=latest.columns, index=latest.index)

    # Normalise to shares, then append row/column totals.
    u = u / u.sum().sum()
    u.loc[:, 'Total'] = u.sum(axis=1)
    u.loc['Total', :] = u.sum(axis=0)

    A = u.loc[:, 'Total'] / empirical_.loc[:, 'Total']
    B = u.loc['Total', :] / empirical_.loc['Total', :]

    if plot_figure:
        barplot_feature_incidence(A, B, save_plot)

    return A, B
    
    
    
